Summary

Background:

This notebook:

Next steps:

Exploratory data analysis

library(covidcast)
library(dplyr)
library(tidyr)
library(tibble)
library(ggplot2)
library(Matrix)

# Load the shared configuration script. The names used later in this notebook
# that are not created here (e.g. `start_day`, `end_day`, `geo_type`) are
# presumably defined by it -- confirm against common_params.R.
source("common_params.R")

# Keep the forecast horizons under the plural name `aheads` and remove the
# original `ahead` binding (presumably to avoid confusion with the `ahead`
# column used in the dplyr calls further down -- confirm).
aheads <- ahead
rm(ahead)
# Download one covidcast signal and reduce it to a tall, labelled tibble.
#
# Arguments:
#   data_source   Passed to `covidcast_signal()` as `data_source`.
#   signal_name   Passed to `covidcast_signal()` as `signal`.
#   signal_label  Value stored in the `signal` column of the result.
#
# Returns a tibble with columns `geo_value`, `time_value`, `value` and
# `signal`, restricted to the locations listed in `states_dc_pr_vi`.
#
# `start_day`, `end_day`, `geo_type` and `states_dc_pr_vi` are read from the
# calling environment; they are defined outside this block (presumably by
# common_params.R -- confirm).
fetch_signal_tall <- function(data_source, signal_name, signal_label) {
  covidcast_signal(
    data_source = data_source,
    signal = signal_name,
    start_day = start_day,
    end_day = end_day,
    geo_type = geo_type
  ) %>%
    tibble() %>%
    filter(geo_value %in% states_dc_pr_vi) %>%
    select(geo_value, time_value, value) %>%
    mutate(signal = signal_label)
}

# The three signals used throughout the notebook. Each label is a fixed
# string, independent of whatever signal name the configuration variable holds.
flu_hosp_prop_data <- fetch_signal_tall(
  hhs_source, flu_hosp_prop, "flu_hosp_prop"
)

covid_hosp_prop_data <- fetch_signal_tall(
  hhs_source, covid_hosp_prop, "covid_hosp_prop"
)

covid_case_prop_data <- fetch_signal_tall(
  jhu_source, covid_case_prop, "covid_case_prop"
)

# Tall table: the three signals stacked row-wise and told apart by their
# `signal` column.
data_tall <- bind_rows(
  flu_hosp_prop_data, covid_hosp_prop_data, covid_case_prop_data
)

# Wide table: one row per (geo_value, time_value) with one value column per
# signal. Inner joins keep only the location/date pairs present in all three
# signals.
data_wide <- flu_hosp_prop_data %>%
  select(geo_value, time_value, value_flu_hosp_prop = value) %>%
  inner_join(
    select(
      covid_hosp_prop_data,
      geo_value, time_value, value_covid_hosp_prop = value
    ),
    by = c("geo_value", "time_value")
  ) %>%
  inner_join(
    select(
      covid_case_prop_data,
      geo_value, time_value, value_covid_case_prop = value
    ),
    by = c("geo_value", "time_value")
  )

In the following figure, we restrict the range of the flu hospitalization rate to be within 0 and 0.2, to improve the clarity of the plot. In a very few cases, there are abnormally large or negative values; I believe these to be data reporting errors. Please refer to the appendix.

# One panel per signal, one line per location. Flu hospitalization rows are
# kept only when their value lies in [0, 0.20]; rows of the other two signals
# are always kept.
plt <- data_tall %>%
  filter(
    signal != "flu_hosp_prop" | (value >= 0 & value <= 0.20)
  ) %>%
  ggplot(aes(x = time_value, y = value, color = geo_value)) +
  geom_line() +
  facet_wrap(vars(signal), nrow = 3, scales = "free_y")

plt

COVID case rates and COVID hospitalization rates both obey rough trends dominated by the “surge” structure. To some extent, hospitalization rates are a lagged version of cases. This effect appears most exaggerated during the holiday season 2020.

The flu hospitalization rate shows little similarity with either of the COVID signals, save for what appears to be an intensification of hospitalizations during the holiday season 2020. The signal appears to largely be noise. (Since we only have flu hospitalization data starting from late-2020, none of the data follows the classical, pre-COVID flu season trajectory).

It is honestly hard to see anything other than a baseline forecaster (i.e., propagate the previous week forward) doing well in this setting.

# Flu hospitalization rates from 2021-10-01 onward, one line per location.
plt <- ggplot(
  filter(
    data_tall,
    signal == "flu_hosp_prop",
    time_value >= "2021-10-01"
  ),
  aes(x = time_value, y = value, color = geo_value)
) +
  geom_line() +
  ggtitle("Flu hospitalization rates, forecast_date >= '2021-10-01'")

plt

Decomposing the distribution of flu hospitalization rate by state, it appears that most states have a distribution centered around 0.02, with a small standard deviation. A few of the states have a wider spread.

Concerningly, some of the states have all zeros. To be fair, these states are also those with small populations. My bet is that there are still hospitalizations there, but that hospitals are not properly reporting flu-related hospitalizations.

# Flu hospitalization rates from 2021-10-01 onward.
plt_df <- filter(
  data_tall,
  signal == "flu_hosp_prop",
  time_value >= "2021-10-01"
)

# Per-location standard deviation, folded into a facet label of the form
# "<geo_value>, sd=<sd to three decimals>".
plt_sds <- plt_df %>%
  group_by(geo_value) %>%
  summarize(sd = sd(value)) %>%
  ungroup() %>%
  mutate(geo_value_sd = sprintf("%s, sd=%0.3f", geo_value, sd))

# Attach the label (and the sd itself) to every observation.
plt_df <- inner_join(plt_df, plt_sds, by = "geo_value")

# One histogram per location, six panels per row, with enlarged text.
plt <- ggplot(plt_df, aes(x = value)) +
  geom_histogram() +
  facet_wrap(vars(geo_value_sd), ncol = 6) +
  theme(
    axis.title.x = element_text(size = 20),
    axis.title.y = element_text(size = 20),
    axis.text.x = element_text(size = 12),
    axis.text.y = element_text(size = 12),
    strip.text.x = element_text(size = 16)
  ) +
  ggtitle("Flu hospitalization rates, forecast_date >= '2021-10-01'")
plt
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Forecasting results

We trained two forecasters, using forecast dates from October 1, 2021 to November 6, 2021. For each forecast date, we predict 5 days ahead to 28 days ahead. Predictions are made for the seven county-level quantiles prescribed by the COVIDhub. (The CDC requests predictions for 23 quantiles). Training was done on the rates scale; one can transform the predictions to the counts scale (as requested by the CDC) post hoc.

Here, we plot mean WIS’s, aggregated across forecast dates and locations, by ahead. We also separately aggregate, excluding locations that only observed zeros during the entirety of the forecasting task.

# Scored forecasts; the code below uses its `geo_value`, `forecaster`,
# `ahead`, `wis` and `ae` columns.
results <- readRDS("results/results.RDS")

# Locations whose flu hospitalization rate is zero on every day from
# 2021-10-01 onward (the share of non-zero observations is exactly 0).
states_all_zero <- data_tall %>%
  filter(signal == "flu_hosp_prop", time_value >= "2021-10-01") %>%
  group_by(geo_value) %>%
  summarize(prop_nonzero = mean(value != 0)) %>%
  # Alternative, looser threshold that was tried: prop_nonzero <= 0.20
  filter(prop_nonzero == 0)

# 2019 population estimate per location, keyed by lower-case abbreviation so
# that it matches `geo_value`.
state_pop <- covidcast::state_census %>%
  tibble() %>%
  transmute(
    geo_value = stringr::str_to_lower(ABBR),
    pop = POPESTIMATE2019
  )

# Add population-scaled versions of the two error measures.
results <- results %>%
  left_join(state_pop, by = "geo_value") %>%
  mutate(pop_wis = wis * pop, pop_ae = ae * pop)

cat(
  "States with all zeros: ",
  toString(states_all_zero$geo_value),
  ".\n",
  sep = ""
)
## States with all zeros: mt, vi, vt.
# Average the scores per (forecaster, ahead) and tag the rows with the name of
# the aggregation they belong to.
#
# Arguments:
#   scores             Data frame with columns `forecaster`, `ahead`, `wis`
#                      and `pop_wis`.
#   aggregation_label  String stored in the `aggregation` column.
#
# Returns an ungrouped tibble with columns `forecaster`, `ahead`, `mean_wis`,
# `mean_pop_wis` and `aggregation`. The explicit `ungroup()` drops the
# `forecaster` grouping that `summarize()` would otherwise leave behind.
summarize_scores_by_ahead <- function(scores, aggregation_label) {
  scores %>%
    group_by(forecaster, ahead) %>%
    summarize(
      mean_wis = mean(wis),
      mean_pop_wis = mean(pop_wis)
    ) %>%
    ungroup() %>%
    mutate(aggregation = aggregation_label)
}

# Two aggregations, stacked: every location, and every location except those
# listed in `states_all_zero`.
mean_results <- bind_rows(
  summarize_scores_by_ahead(results, "all"),
  summarize_scores_by_ahead(
    filter(results, !(geo_value %in% states_all_zero$geo_value)),
    "exclude_all_zero"
  )
)

# Mean WIS against ahead: colour distinguishes forecasters, line type the
# aggregation. Dotted vertical guides sit at aheads 5, 12, 19 and 26.
plt <- ggplot(
  mean_results,
  aes(
    x = ahead,
    y = mean_wis,
    linetype = aggregation,
    color = forecaster
  )
) +
  geom_point() +
  geom_line() +
  geom_vline(xintercept = c(5, 12, 19, 26), linetype = "dotted")

plt

Here, we plot population-weighted mean WIS’s, aggregated across forecast dates and locations, by ahead.

# Same layout as the mean-WIS plot, but with the population-scaled mean WIS on
# the y axis.
plt <- ggplot(
  mean_results,
  aes(
    x = ahead,
    y = mean_pop_wis,
    linetype = aggregation,
    color = forecaster
  )
) +
  geom_point() +
  geom_line() +
  geom_vline(xintercept = c(5, 12, 19, 26), linetype = "dotted")

plt

Diagnostics

Diagnostic plots suggest that in most states (larger-population states?), the residuals are moderately well-behaved, but pathological behavior (high-leverage points, significant departures from the assumption of Gaussian errors) is visible in some states, especially those with smaller populations (would a count-based model be more appropriate here?).

# Stored predictions and the corresponding observed values.
preds <- readRDS("predictions/predictions.RDS")
actuals <- readRDS("actuals.RDS")

# Pair each prediction with the observation for the same target date and
# location, then add the absolute error (`ae`) and the signed residual
# (`res`, prediction minus actual).
evals <- inner_join(
  preds,
  actuals,
  by = c("target_end_date", "geo_value")
) %>%
  mutate(
    ae = abs(value - actual),
    res = value - actual
  )

# Normal QQ plot of the residuals at ahead 12 and quantile level 0.5, one
# panel per location. The red reference line is added first so the points are
# drawn on top of it.
plt <- evals %>%
  filter(ahead == 12, quantile == 0.5) %>%
  ggplot(aes(sample = res)) +
  stat_qq_line(color = "red") +
  stat_qq(size = 0.1) +
  facet_wrap(vars(geo_value), scales = "free_y", ncol = 6) +
  ggtitle("Residual qqplot, ahead=12, tau=0.5")

plt

# Residuals against predicted values at ahead 12 and quantile level 0.5, one
# panel per location, with a red horizontal line at zero.
plt <- evals %>%
  filter(ahead == 12, quantile == 0.5) %>%
  ggplot(aes(value, res)) +
  geom_point(size = 0.1) +
  geom_hline(yintercept = 0, color = "red") +
  facet_wrap(vars(geo_value), scales = "free", ncol = 6) +
  ggtitle("Residuals, ahead=12, tau=0.5")

plt

# Horizons whose fitted coefficients are inspected.
sub_aheads <- c(5, 12, 19)

# Every requested horizon must be one of the fitted horizons; otherwise the
# `[[which(...)]]` lookup below would fail with an opaque subscript error.
stopifnot(all(sub_aheads %in% aheads))

# For each forecast date, read the saved debug output and pull out one
# coefficient vector per horizon in `sub_aheads`.
# `seq_along()` is used rather than `1:length()` so that an empty
# `forecast_dates` yields zero iterations instead of iterating over c(1, 0).
df_list <- vector("list", length(forecast_dates))
for (idx in seq_along(forecast_dates)) {
  debug_results <- readRDS(sprintf(
    "./debug_results/AR_state_%s.RDS",
    forecast_dates[idx]
  ))
  mini_df_list <- vector("list", length(sub_aheads))
  for (sidx in seq_along(sub_aheads)) {
    mini_df_list[[sidx]] <- tibble(
      forecast_date = forecast_dates[idx],
      # Assumes each fit has exactly four coefficients, indexed 0 to 3
      # (TODO confirm what they represent).
      coef = 0:3,
      # Column 4 of `beta` is presumably the tau = 0.5 fit, matching the
      # title of the plot built from `coefs_df` -- confirm against the
      # quantile levels used when fitting.
      value = debug_results$predict_params[[
        which(aheads == sub_aheads[sidx])
      ]]$object$beta[, 4],
      ahead = sub_aheads[sidx]
    )
  }
  df_list[[idx]] <- bind_rows(mini_df_list)
}

# One row per (forecast_date, ahead, coef); `ahead` becomes a factor so that
# it maps to a discrete colour scale.
coefs_df <- bind_rows(df_list)
coefs_df$ahead <- as.factor(coefs_df$ahead)

# Coefficient value against coefficient index, one line per ahead and one
# panel per forecast date.
plt <- coefs_df %>%
  ggplot(aes(x = coef, y = value, color = ahead)) +
  geom_line() +
  facet_wrap(vars(forecast_date)) +
  ggtitle("Coefficient values for predicting the median (tau=0.5)")
plt

Appendix

# All three signals with no value filtering: one panel per signal, one line
# per location, each panel with its own y scale.
plt <- data_tall %>%
  ggplot(aes(x = time_value, y = value, color = geo_value)) +
  geom_line() +
  facet_wrap(vars(signal), nrow = 3, scales = "free_y")

plt